from __future__ import annotations

import math
from copy import deepcopy

import numpy as np
import torch
from isaacgym import gymapi, gymtorch, gymutil  # noqa: F401
from loguru import logger as log

# Optional: RoboSplatter imports for GS background rendering
try:
    from robo_splatter.models.camera import Camera as SplatCamera

    ROBO_SPLATTER_AVAILABLE = True
except ImportError:
    ROBO_SPLATTER_AVAILABLE = False
    log.warning("RoboSplatter not available. GS background rendering will be disabled.")

from metasim.constants import PhysicStateType
from metasim.queries.base import BaseQueryType
from metasim.scenario.objects import (
    ArticulationObjCfg,
    BaseObjCfg,
    PrimitiveCubeCfg,
    PrimitiveSphereCfg,
    RigidObjCfg,
    _FileBasedMixin,
)

# FIXME: fix this
# from metasim.scenario.randomization import FrictionRandomCfg, MassRandomCfg
# NOTE domain randomization for robots
# please refer to roboverse_learn.rl.unitree_rl.config.cfg_randomizer for material and mass randomization for isaacgym and isaacsim
from metasim.scenario.scenario import ScenarioCfg
from metasim.sim import BaseSimHandler
from metasim.types import Action, DictEnvState
from metasim.utils.dict import class_to_dict
from metasim.utils.state import CameraState, ObjectState, RobotState, TensorState


class IsaacgymHandler(BaseSimHandler):
    def __init__(self, scenario: ScenarioCfg, optional_queries: dict[str, BaseQueryType] | None = None):
        super().__init__(scenario, optional_queries)
        self._actions_cache: list[Action] = []
        self._robot_names = {robot.name for robot in self.robots}
        self._robot_init_pos = {robot.name: robot.default_position for robot in self.robots}
        self._robot_init_quat = {robot.name: robot.default_orientation for robot in self.robots}
        self._cameras = scenario.cameras

        self.gym = None
        self.sim = None
        self.viewer = None
        self._enable_viewer_sync: bool = True  # sync viewer flag
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self._num_envs: int = scenario.num_envs
        # self._episode_length_buf = [0 for _ in range(self.num_envs)]

        # asset related
        self._asset_dict_dict: dict = {}  # dict of object link index dict
        self._articulated_asset_dict_dict: dict = {}  # dict of articulated object link index dict
        self._articulated_joint_dict_dict: dict = {}  # dict of articulated object joint index dict
        self._robot_link_dict: dict = {}  # dict of robot link index dict
        self._robot_joint_dict: dict = {}  # dict of robot joint index dict
        self._joint_info: dict = {}  # dict of joint names of each env
        self._num_joints: int = 0
        self._body_info: dict = {}  # dict of body names of each env
        self._num_bodies: int = 0

        # environment related pointers
        self._envs: list = []
        self._obj_handles: list = []  # 2 dim list: list in list, each list contains object handles of each env
        self._articulated_obj_handles: list = []  # 2 dim list: list in list, each list contains articulated object handles of each env
        self._robot_handles: list = []  # 2 dim list: list of robot handles of each env

        # environment related tensor indices
        self._env_rigid_body_global_indices: list = []  # 2 dim list: list in list, each list contains global indices of each env

        # will update after refresh
        self._root_states: torch.Tensor | None = None
        self._dof_states: torch.Tensor | None = None
        self._rigid_body_states: torch.Tensor | None = None
        self._robot_dof_state: torch.Tensor | None = None
        self._contact_forces: torch.Tensor | None = None

        # control related
        self._robot_num_dof: int  # number of robot dof
        self._obj_num_dof: int = 0  # number of object dof
        self._actions: torch.Tensor | None = None
        self._action_scale: torch.Tensor | None = (
            None  # for configuration: desire_pos = action_scale * action + default_pos
        )
        self._robot_default_dof_pos: torch.Tensor | None = (
            None  # for the configuration: desire_pos = action_scale * action + default_pos
        )
        self._action_offset: bool = False  # for configuration: desire_pos = action_scale * action + default_pos
        self._p_gains: torch.Tensor | None = None  # parameter for PD controller in for pd effort control
        self._d_gains: torch.Tensor | None = None
        self._torque_limits: torch.Tensor | None = None
        self._effort: torch.Tensor | None = None  # output of pd controller, used for effort control
        self._dof_force: torch.Tensor | None = None  # measured DOF forces from simulator
        self._pos_ctrl_dof_dix = []  # joint index in dof state, built-in position control mode
        self._manual_pd_on: bool = False  # turn on maunual pd controller if effort joint exist

    def launch(self) -> None:
        ## IsaacGym Initialization
        self._init_gym()
        self._make_envs()
        self._set_up_camera()
        if self.scenario.gs_scene is not None and self.scenario.gs_scene.with_gs_background:
            self._build_gs_background()
        # ==== prepare tensors =====
        # from now on, we will use the tensor API that can run on CPU or GPU
        self.gym.prepare_sim(self.sim)
        self._root_states = gymtorch.wrap_tensor(self.gym.acquire_actor_root_state_tensor(self.sim))
        self._dof_states = gymtorch.wrap_tensor(self.gym.acquire_dof_state_tensor(self.sim))
        self._rigid_body_states = gymtorch.wrap_tensor(self.gym.acquire_rigid_body_state_tensor(self.sim))
        # measured per-DOF forces/torques from simulator
        self._dof_force = gymtorch.wrap_tensor(self.gym.acquire_dof_force_tensor(self.sim))
        self._robot_dof_state = self._dof_states.view(self._num_envs, -1, 2)[:, self._obj_num_dof :]
        self._contact_forces = gymtorch.wrap_tensor(self.gym.acquire_net_contact_force_tensor(self.sim))

        # Refresh tensors
        if not self._manual_pd_on:
            self.gym.refresh_dof_state_tensor(self.sim)
        self.gym.refresh_rigid_body_state_tensor(self.sim)
        self.gym.refresh_actor_root_state_tensor(self.sim)
        self.gym.refresh_jacobian_tensors(self.sim)
        self.gym.refresh_mass_matrix_tensors(self.sim)
        self.gym.refresh_net_contact_force_tensor(self.sim)
        # refresh measured dof forces if available
        self.gym.refresh_dof_force_tensor(self.sim)

        # if self.optional_queries is None:
        #     self.optional_queries = {}
        # for query_name, query_type in self.optional_queries.items():
        #     query_type.bind_handler(self)
        return super().launch()

    def _init_gym(self) -> None:
        physics_engine = gymapi.SIM_PHYSX
        self.gym = gymapi.acquire_gym()
        # configure sim
        # TODO move more params into sim_params cfg
        sim_params = gymapi.SimParams()
        sim_params.up_axis = gymapi.UP_AXIS_Z
        sim_params.gravity = gymapi.Vec3(0.0, 0.0, -9.8)
        if self.scenario.sim_params.dt is not None:
            # IsaacGym has a different dt definition than IsaacLab, see https://isaac-sim.github.io/IsaacLab/main/source/migration/migrating_from_isaacgymenvs.html#simulation-config
            sim_params.dt = self.scenario.sim_params.dt
        sim_params.substeps = self.scenario.sim_params.substeps
        sim_params.use_gpu_pipeline = self.scenario.sim_params.use_gpu_pipeline
        sim_params.physx.solver_type = self.scenario.sim_params.solver_type
        sim_params.physx.num_position_iterations = self.scenario.sim_params.num_position_iterations
        sim_params.physx.num_velocity_iterations = self.scenario.sim_params.num_velocity_iterations
        sim_params.physx.rest_offset = self.scenario.sim_params.rest_offset
        sim_params.physx.contact_offset = self.scenario.sim_params.contact_offset
        sim_params.physx.friction_offset_threshold = self.scenario.sim_params.friction_offset_threshold
        sim_params.physx.friction_correlation_distance = self.scenario.sim_params.friction_correlation_distance
        sim_params.physx.num_threads = self.scenario.sim_params.num_threads
        sim_params.physx.use_gpu = self.scenario.sim_params.use_gpu
        sim_params.physx.bounce_threshold_velocity = self.scenario.sim_params.bounce_threshold_velocity
        sim_params.physx.max_depenetration_velocity = self.scenario.sim_params.max_depenetration_velocity
        sim_params.physx.default_buffer_size_multiplier = self.scenario.sim_params.default_buffer_size_multiplier

        compute_device_id = 0
        graphics_device_id = 0
        if self.headless and len(self.cameras) == 0:
            graphics_device_id = -1
        self.sim = self.gym.create_sim(compute_device_id, graphics_device_id, physics_engine, sim_params)
        if self.sim is None:
            raise Exception("Failed to create sim")
        if not self.headless:
            self.viewer = self.gym.create_viewer(self.sim, gymapi.CameraProperties())
            # press 'V' to toggle viewer sync
            self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_V, "toggle_viewer_sync")
            if self.viewer is None:
                raise Exception("Failed to create viewer")

    def _get_camera_params(self, vinv_matrix, proj_matrix, width, height):
        """Get camera intrinsics and extrinsics from IsaacGym matrices.

        Args:
            vinv_matrix: IsaacGym view inverse matrix (4x4 tensor)
            proj_matrix: IsaacGym projection matrix (4x4 tensor)
            width: image width
            height: image height

        Returns:
            Ks: (3, 3) intrinsic matrix
            c2w: (4, 4) camera-to-world transformation matrix
        """
        # Extrinsics: vinv_matrix is already camera-to-world
        # IsaacGym returns row-major matrix; transpose to standard c2w format
        c2w = vinv_matrix.T.to(self.device, dtype=torch.float32)

        # Intrinsics: extract from projection matrix and build torch tensor on device
        fx = proj_matrix[0, 0] * (width / 2.0)
        fy = proj_matrix[1, 1] * (height / 2.0)
        Ks = torch.zeros((3, 3), dtype=torch.float32, device=self.device)
        Ks[0, 0] = fx
        Ks[1, 1] = fy
        Ks[0, 2] = float(width) / 2.0
        Ks[1, 2] = float(height) / 2.0
        Ks[2, 2] = 1.0

        return Ks, c2w

    def _apply_gs_background_rendering(self, camera_states, env_ids):
        """Apply GS background rendering to camera states using pure tensor operations.

        Args:
            camera_states: The camera states
            env_ids: The environment ids

        Returns:
            The camera states with blended results
        """
        from metasim.utils.gs_util import alpha_blend_rgba_torch

        if not ROBO_SPLATTER_AVAILABLE or self.gs_background is None:
            return camera_states

        for cam_id, cam in enumerate(self.cameras):
            # Get camera parameters for the first environment
            env_id = env_ids[0] if env_ids else 0
            vinv_matrix = self._vinv_mats[env_id][cam_id]
            proj_matrix = self._proj_mats[env_id][cam_id]
            width, height = cam.width, cam.height

            # Extract camera parameters
            Ks, c2w = self._get_camera_params(vinv_matrix, proj_matrix, width, height)

            # Render GS background using tensor-native camera init
            gs_cam = SplatCamera.init_from_pose_tensor(
                c2w=c2w,
                Ks=Ks,
                image_height=height,
                image_width=width,
                device=self.device,
            )
            gs_result = self.gs_background.render(gs_cam)

            # Get GS background and normalize to tensors on device (handle numpy or torch)
            gs_rgb = gs_result.rgb[0].to(self.device)
            gs_depth = gs_result.depth[0].to(self.device)

            # Ensure depth is 2D
            if gs_depth.ndim == 3 and gs_depth.shape[-1] == 1:
                gs_depth = gs_depth.squeeze(-1)

            # Get simulation rendering (already tensors)
            sim_rgb = camera_states[cam.name].rgb  # Shape: (num_envs, height, width, 3)
            sim_depth = camera_states[cam.name].depth  # Shape: (num_envs, height, width)

            # Process each environment
            blended_rgb_list = []
            blended_depth_list = []

            for env_idx in range(sim_rgb.shape[0]):
                # Get segmentation mask
                env_id_actual = env_ids[env_idx] if env_idx < len(env_ids) else env_idx
                seg_tensor = self._seg_tensors[env_id_actual][cam_id]

                # Alpha: exclude background (-1) and ground plane (0)
                alpha = ((seg_tensor > -1) & (seg_tensor != 0)).float()

                # Get simulation RGB and depth
                sim_rgb_env = sim_rgb[env_idx].float()
                sim_rgb_env = (sim_rgb_env / 255.0).clamp(0.0, 1.0)
                sim_depth_env = sim_depth[env_idx]

                # Alpha blend using torch
                blended_rgb = alpha_blend_rgba_torch(sim_rgb_env, gs_rgb, alpha)

                blended_rgb = (blended_rgb * 255.0).clamp(0.0, 255.0).to(torch.uint8)

                # Compose depth using boolean mask
                mask_bool = alpha.squeeze(-1) > 0.5
                blended_depth = torch.where(mask_bool, sim_depth_env, gs_depth)

                blended_rgb_list.append(blended_rgb)
                blended_depth_list.append(blended_depth)

            # Update camera state with blended results
            camera_states[cam.name] = CameraState(
                rgb=torch.stack(blended_rgb_list),
                depth=torch.stack(blended_depth_list),
            )

        return camera_states

    def _set_up_camera(self) -> None:
        self._depth_tensors = []
        self._rgb_tensors = []
        self._seg_tensors = []
        self._vinv_mats = []
        self._proj_mats = []
        self._camera_handles = []
        self._env_origin = []
        for i_env in range(self.num_envs):
            self._depth_tensors.append([])
            self._rgb_tensors.append([])
            self._seg_tensors.append([])
            self._vinv_mats.append([])
            self._proj_mats.append([])
            self._env_origin.append([])

            origin = self.gym.get_env_origin(self._envs[i_env])
            self._env_origin[i_env] = [origin.x, origin.y, origin.z]
            for cam_cfg in self.cameras:
                camera_props = gymapi.CameraProperties()
                camera_props.width = cam_cfg.width
                camera_props.height = cam_cfg.height
                camera_props.horizontal_fov = cam_cfg.horizontal_fov
                camera_props.near_plane = cam_cfg.clipping_range[0]
                camera_props.far_plane = cam_cfg.clipping_range[1]
                camera_props.enable_tensors = True
                camera_handle = self.gym.create_camera_sensor(self._envs[i_env], camera_props)

                self._camera_handles.append(camera_handle)

                camera_eye = gymapi.Vec3(*cam_cfg.pos)
                camera_lookat = gymapi.Vec3(*cam_cfg.look_at)
                self.gym.set_camera_location(camera_handle, self._envs[i_env], camera_eye, camera_lookat)
                if cam_cfg.mount_to is not None:
                    if isinstance(cam_cfg.mount_link, str):
                        mount_handle = self._robot_link_dict[cam_cfg.mount_link]
                    elif isinstance(cam_cfg.mount_link, tuple):
                        mount_handle = self._robot_link_dict[cam_cfg.mount_link[1]]
                    camera_pose = gymapi.Transform(
                        gymapi.Vec3(*cam_cfg.mount_pos), gymapi.Quat(*cam_cfg.mount_quat[1:], cam_cfg.mount_quat[0])
                    )
                    self.gym.attach_camera_to_body(
                        camera_handle, self._envs[i_env], mount_handle, camera_pose, gymapi.FOLLOW_TRANSFORM
                    )

                camera_tensor_depth = self.gym.get_camera_image_gpu_tensor(
                    self.sim, self._envs[i_env], camera_handle, gymapi.IMAGE_DEPTH
                )
                camera_tensor_rgb = self.gym.get_camera_image_gpu_tensor(
                    self.sim, self._envs[i_env], camera_handle, gymapi.IMAGE_COLOR
                )
                camera_tensor_rgb_seg = self.gym.get_camera_image_gpu_tensor(
                    self.sim, self._envs[i_env], camera_handle, gymapi.IMAGE_SEGMENTATION
                )
                torch_cam_depth_tensor = gymtorch.wrap_tensor(camera_tensor_depth)
                torch_cam_rgb_tensor = gymtorch.wrap_tensor(camera_tensor_rgb)
                torch_cam_rgb_seg_tensor = gymtorch.wrap_tensor(camera_tensor_rgb_seg)

                cam_vinv = torch.inverse(
                    torch.tensor(self.gym.get_camera_view_matrix(self.sim, self._envs[i_env], camera_handle))
                ).to(self.device)
                cam_proj = torch.tensor(
                    self.gym.get_camera_proj_matrix(self.sim, self._envs[i_env], camera_handle),
                    device=self.device,
                )

                self._depth_tensors[i_env].append(torch_cam_depth_tensor)
                self._rgb_tensors[i_env].append(torch_cam_rgb_tensor)
                self._seg_tensors[i_env].append(torch_cam_rgb_seg_tensor)
                self._vinv_mats[i_env].append(cam_vinv)
                self._proj_mats[i_env].append(cam_proj)

    def _load_object_asset(self, object: BaseObjCfg) -> None:
        asset_root = "."
        if isinstance(object, PrimitiveCubeCfg):
            asset_options = gymapi.AssetOptions()
            asset_options.armature = 0.01
            asset_options.fix_base_link = object.fix_base_link
            asset_options.disable_gravity = not object.enabled_gravity
            asset_options.flip_visual_attachments = False
            asset = self.gym.create_box(self.sim, object.size[0], object.size[1], object.size[2], asset_options)
        elif isinstance(object, PrimitiveSphereCfg):
            asset_options = gymapi.AssetOptions()
            asset_options.armature = 0.01
            asset_options.fix_base_link = object.fix_base_link
            asset_options.disable_gravity = not object.enabled_gravity
            asset_options.flip_visual_attachments = False
            asset = self.gym.create_sphere(self.sim, object.radius, asset_options)

        elif isinstance(object, ArticulationObjCfg):
            asset_path = object.file_name("isaacgym")
            asset_options = gymapi.AssetOptions()
            asset_options.armature = 0.01
            asset_options.fix_base_link = object.fix_base_link
            asset_options.disable_gravity = not object.enabled_gravity
            asset_options.flip_visual_attachments = False
            asset = self.gym.load_asset(self.sim, asset_root, asset_path, asset_options)
            self._articulated_asset_dict_dict[object.name] = self.gym.get_asset_rigid_body_dict(asset)
            self._articulated_joint_dict_dict[object.name] = self.gym.get_asset_dof_dict(asset)
        elif isinstance(object, RigidObjCfg):
            asset_path = object.file_name("isaacgym")
            asset_options = gymapi.AssetOptions()
            # Only set fix_base_link if it's True (non-default)
            if object.fix_base_link:
                asset_options.fix_base_link = True
            # For XFORM physics (goal object), disable gravity
            if hasattr(object, "physics") and object.physics == PhysicStateType.XFORM:
                asset_options.disable_gravity = True
            asset_options.use_mesh_materials = True
            asset_options.collapse_fixed_joints = getattr(object, "collapse_fixed_joints", False)
            asset = self.gym.load_asset(self.sim, asset_root, asset_path, asset_options)

        asset_link_dict = self.gym.get_asset_rigid_body_dict(asset)
        self._asset_dict_dict[object.name] = asset_link_dict
        self._obj_num_dof += self.gym.get_asset_dof_count(asset)
        return asset

    def _load_robot_assets(self) -> None:
        asset_root = "."
        # FIXME: hard code for only one robot
        assert len(self.robots) == 1, "Only support one robot for now"
        robot_asset_file = self.robots[0].file_name("isaacgym")
        asset_options = gymapi.AssetOptions()
        asset_options.armature = getattr(self.robots[0], "armature", 0.01)
        asset_options.fix_base_link = self.robots[0].fix_base_link
        asset_options.disable_gravity = not self.robots[0].enabled_gravity
        asset_options.flip_visual_attachments = self.robots[0].isaacgym_flip_visual_attachments
        asset_options.collapse_fixed_joints = self.robots[0].collapse_fixed_joints
        asset_options.default_dof_drive_mode = gymapi.DOF_MODE_NONE
        # Defaults are set to free movement and will be updated based on the configuration in actuator_cfg below.
        asset_options.replace_cylinder_with_capsule = self.scenario.sim_params.replace_cylinder_with_capsule
        robot_asset = self.gym.load_asset(self.sim, asset_root, robot_asset_file, asset_options)
        # configure robot dofs
        robot_num_dofs = self.gym.get_asset_dof_count(robot_asset)
        self._robot_num_dof = robot_num_dofs

        # FIXME: hard code for 0-1 action space, should remove all the scale stuff later
        self._action_scale = torch.tensor(1.0, device=self.device)
        self._action_offset = torch.tensor(0.0, device=self.device)

        self._torque_limits = torch.zeros(
            self._num_envs, robot_num_dofs, dtype=torch.float, device=self.device, requires_grad=False
        )

        robot_dof_props = self.gym.get_asset_dof_properties(robot_asset)

        robot_lower_limits = robot_dof_props["lower"]
        robot_upper_limits = robot_dof_props["upper"]
        robot_mids = 0.3 * (robot_upper_limits + robot_lower_limits)
        num_actions = 0
        default_dof_pos = []

        assert self.robots[0].control_type is not None, "Control type is required for robot"
        self._manual_pd_on = any(mode == "effort" for mode in self.robots[0].control_type.values())

        dof_names = self.gym.get_asset_dof_names(robot_asset)
        for i, dof_name in enumerate(dof_names):
            # get config
            i_actuator_cfg = self.robots[0].actuators[dof_name]
            i_control_mode = (
                self.robots[0].control_type[dof_name] if dof_name in self.robots[0].control_type else "position"
            )

            # task default position from cfg if exist, otherwise use 0.3*(uppper + lower) as default
            if not i_actuator_cfg.is_ee:
                default_dof_pos_i = (
                    self.robots[0].default_joint_positions[dof_name]
                    if dof_name in self.robots[0].default_joint_positions
                    else robot_mids[i]
                )
                default_dof_pos.append(default_dof_pos_i)
            # for end effector, always use open as default position
            else:
                default_dof_pos.append(robot_upper_limits[i])
            # pd control effort mode
            if i_control_mode == "effort":
                # FIXME: hard code for 0-1 action space, should remove all the scale stuff later

                robot_dof_props["driveMode"][i] = gymapi.DOF_MODE_EFFORT
                # robot_dof_props["stiffness"][i] = i_actuator_cfg.stiffness
                # robot_dof_props["damping"][i] = i_actuator_cfg.damping
                robot_dof_props["stiffness"][i] = 0.0
                robot_dof_props["damping"][i] = 0.0
                robot_dof_props["armature"][i] = getattr(self.robots[0], "armature", 0.01)

            # built-in position mode
            elif i_control_mode == "position":
                robot_dof_props["driveMode"][i] = gymapi.DOF_MODE_POS
                if i_actuator_cfg.stiffness is not None:
                    robot_dof_props["stiffness"][i] = i_actuator_cfg.stiffness
                else:
                    robot_dof_props["stiffness"][i] = 400.0
                if i_actuator_cfg.damping is not None:
                    robot_dof_props["damping"][i] = i_actuator_cfg.damping
                else:
                    robot_dof_props["damping"][i] = 40.0
                self._pos_ctrl_dof_dix.append(i + self._obj_num_dof)
            else:
                log.error(f"Unknown actuator control mode: {i_control_mode}, only support effort and position")
                raise ValueError

            if i_actuator_cfg.fully_actuated:
                num_actions += 1

        # joint_reindex = self.get_joint_reindex(self.robot.name)
        self._robot_default_dof_pos = torch.tensor(default_dof_pos, device=self.device).unsqueeze(0)
        self.actions = torch.zeros([self._num_envs, num_actions], device=self.device)

        # # get link index of panda hand, which we will use as end effector
        self._robot_link_dict = self.gym.get_asset_rigid_body_dict(robot_asset)
        self._robot_joint_dict = self.gym.get_asset_dof_dict(robot_asset)

        return robot_asset, robot_dof_props

    def _make_envs(
        self,
    ) -> None:
        # configure env grid
        num_per_row = int(math.sqrt(self.num_envs))
        spacing = self.scenario.env_spacing
        env_lower = gymapi.Vec3(-spacing, -spacing, 0.0)
        env_upper = gymapi.Vec3(spacing, spacing, spacing)
        log.info("Creating %d environments" % self.num_envs)

        # FIXME: hard code for only one robot
        assert len(self._robot_init_pos) == 1, "Only support one robot for now"
        robot_pose = gymapi.Transform()
        robot_pose.p = gymapi.Vec3(*self._robot_init_pos[self.robots[0].name])
        robot_pose.r = gymapi.Quat(
            *self._robot_init_quat[self.robots[0].name][1:], self._robot_init_quat[self.robots[0].name][0]
        )  # x, y, z, w order for gymapi.Quat

        # add ground plane
        plane_params = gymapi.PlaneParams()
        plane_params.normal = gymapi.Vec3(0, 0, 1)
        self.gym.add_ground(self.sim, plane_params)

        # get object and robot asset
        obj_assets_list = [self._load_object_asset(obj) for obj in self.objects]
        robot_asset, robot_dof_props = self._load_robot_assets()
        # robot_rigid_shape_props_asset = self.gym.get_asset_rigid_shape_properties(robot_asset)

        #### Joint Info ####
        for art_obj_name, art_obj_joint_dict in self._articulated_joint_dict_dict.items():
            num_joints = len(art_obj_joint_dict)
            joint_names_ = []
            for joint_i in range(num_joints):
                for joint_name, joint_idx in art_obj_joint_dict.items():
                    if joint_idx == joint_i:
                        joint_names_.append(joint_name)
            assert len(joint_names_) == num_joints
            joint_info_ = {}
            joint_info_["names"] = joint_names_
            joint_info_["local_indices"] = art_obj_joint_dict
            art_obj_joint_dict_global = {k_: v_ + self._num_joints for k_, v_ in art_obj_joint_dict.items()}
            joint_info_["global_indices"] = art_obj_joint_dict_global
            self._num_joints += num_joints
            self._joint_info[art_obj_name] = joint_info_

        # robot
        num_joints = len(self._robot_joint_dict)
        joint_names_ = []
        for joint_i in range(num_joints):
            for joint_name, joint_idx in self._robot_joint_dict.items():
                if joint_idx == joint_i:
                    joint_names_.append(joint_name)

        assert len(joint_names_) == num_joints
        joint_info_ = {}
        joint_info_["names"] = joint_names_
        joint_info_["local_indices"] = self._robot_joint_dict
        joint_info_["global_indices"] = {k_: v_ + self._num_joints for k_, v_ in self._robot_joint_dict.items()}
        self._joint_info[self.robots[0].name] = joint_info_
        self._num_joints += num_joints

        ###################
        #### Body Info ####
        for obj_name, asset_dict in self._asset_dict_dict.items():
            num_bodies = len(asset_dict)
            rigid_body_names = []
            for i in range(num_bodies):
                for rigid_body_name, rigid_body_idx in asset_dict.items():
                    if rigid_body_idx == i:
                        rigid_body_names.append(rigid_body_name)
            assert len(rigid_body_names) == num_bodies
            body_info_ = {}
            body_info_["names"] = rigid_body_names
            body_info_["local_indices"] = asset_dict
            body_info_["global_indices"] = {k_: v_ + self._num_bodies for k_, v_ in asset_dict.items()}
            self._body_info[obj_name] = body_info_
            self._num_bodies += num_bodies

        num_bodies = len(self._robot_link_dict)
        rigid_body_names = []
        for i in range(num_bodies):
            for rigid_body_name, rigid_body_idx in self._robot_link_dict.items():
                if rigid_body_idx == i:
                    rigid_body_names.append(rigid_body_name)

        assert len(rigid_body_names) == num_bodies
        rigid_body_info_ = {}
        rigid_body_info_["names"] = rigid_body_names
        rigid_body_info_["local_indices"] = self._robot_link_dict
        rigid_body_info_["global_indices"] = {k_: v_ + self._num_bodies for k_, v_ in self._robot_link_dict.items()}
        self._body_info[self.robots[0].name] = rigid_body_info_
        self._num_bodies += num_bodies

        #################

        for i in range(self.num_envs):
            # create env
            env = self.gym.create_env(self.sim, env_lower, env_upper, num_per_row)

            ##  state update  ##
            self._envs.append(env)
            self._obj_handles.append([])
            self._env_rigid_body_global_indices.append({})
            self._articulated_obj_handles.append([])
            ####################

            # carefully set each object
            for obj_i, obj_asset in enumerate(obj_assets_list):
                # add object
                obj_pose = gymapi.Transform()
                obj = self.objects[obj_i]
                # Use default position from object configuration
                obj_pose.p.x = obj.default_position[0]
                obj_pose.p.y = obj.default_position[1]
                obj_pose.p.z = obj.default_position[2]
                # Use default orientation from object configuration
                obj_pose.r = gymapi.Quat(
                    obj.default_orientation[1],
                    obj.default_orientation[2],
                    obj.default_orientation[3],
                    obj.default_orientation[0],
                )  # x, y, z, w order
                # Create actor with collision group 0 and filter 0 (matches IsaacGymEnvs)
                segmentation_id = obj_i + 1
                obj_handle = self.gym.create_actor(env, obj_asset, obj_pose, obj.name, i, 0, segmentation_id)

                if isinstance(self.objects[obj_i], _FileBasedMixin):
                    self.gym.set_actor_scale(env, obj_handle, self.objects[obj_i].scale[0])
                elif isinstance(self.objects[obj_i], PrimitiveCubeCfg):
                    color = gymapi.Vec3(
                        self.objects[obj_i].color[0],
                        self.objects[obj_i].color[1],
                        self.objects[obj_i].color[2],
                    )
                    self.gym.set_rigid_body_color(env, obj_handle, 0, gymapi.MESH_VISUAL_AND_COLLISION, color)
                elif isinstance(self.objects[obj_i], PrimitiveSphereCfg):
                    color = gymapi.Vec3(
                        self.objects[obj_i].color[0],
                        self.objects[obj_i].color[1],
                        self.objects[obj_i].color[2],
                    )
                    self.gym.set_rigid_body_color(env, obj_handle, 0, gymapi.MESH_VISUAL_AND_COLLISION, color)
                elif isinstance(self.objects[obj_i], RigidObjCfg):
                    pass
                elif isinstance(self.objects[obj_i], ArticulationObjCfg):
                    self._articulated_obj_handles[-1].append(obj_handle)
                else:
                    log.error("Unknown object type")
                    raise NotImplementedError
                self._obj_handles[-1].append(obj_handle)

                object_rigid_body_indices = {}
                for rigid_body_name, local_idx in self._asset_dict_dict[self.objects[obj_i].name].items():
                    global_rigid_body_idx = self.gym.find_actor_rigid_body_index(
                        env, obj_handle, rigid_body_name, gymapi.DOMAIN_SIM
                    )
                    object_rigid_body_indices[rigid_body_name] = global_rigid_body_idx

                self._env_rigid_body_global_indices[-1][self.objects[obj_i].name] = object_rigid_body_indices

            # # carefully add robot
            robot_segmentation_id = len(self.objects) + 1
            _enabled_self_collisions = 0 if self.robots[0].enabled_self_collisions else 2
            robot_handle = self.gym.create_actor(
                env, robot_asset, robot_pose, "robot", i, _enabled_self_collisions, robot_segmentation_id
            )
            assert self.robots[0].scale[0] == 1.0 and self.robots[0].scale[1] == 1.0 and self.robots[0].scale[2] == 1.0
            self._robot_handles.append(robot_handle)
            # set dof properties
            self.gym.set_actor_dof_properties(env, robot_handle, robot_dof_props)

            robot_rigid_body_indices = {}
            for rigid_body_name, rigid_body_idx in self._robot_link_dict.items():
                rigid_body_idx = self.gym.find_actor_rigid_body_index(
                    env, robot_handle, rigid_body_name, gymapi.DOMAIN_SIM
                )
                robot_rigid_body_indices[rigid_body_name] = rigid_body_idx

            self._env_rigid_body_global_indices[-1]["robot"] = robot_rigid_body_indices

            # domain randomization for robots
            # FIXME: add domain randomization with new API
            # self.rand_rigid_body_fric(self.scenario.random.friction, i, robot_rigid_shape_props_asset)
            # robot_body_props = self.gym.get_actor_rigid_body_properties(env, robot_handle)
            # self.rand_rigid_body_mass(self.scenario.random.mass, i, robot_body_props)

        # GET initial state, copy for reset later
        self._initial_state = np.copy(self.gym.get_sim_rigid_body_states(self.sim, gymapi.STATE_ALL))

        ###### set VEWIER camera ######
        # point camera at middle env
        if not self.headless:  # TODO: update a default viewer
            cam_pos = gymapi.Vec3(1, 1, 1)
            cam_target = gymapi.Vec3(-1, -1, -0.5)
            middle_env = self._envs[self.num_envs // 2 + num_per_row // 2]
            self.gym.viewer_camera_look_at(self.viewer, middle_env, cam_pos, cam_target)
        ################################

    def _reorder_quat_xyzw_to_wxyz(self, state: torch.Tensor, reverse: bool = False) -> torch.Tensor:
        return (
            state[..., [0, 1, 2, 4, 5, 6, 3, 7, 8, 9, 10, 11, 12]]
            if reverse
            else state[..., [0, 1, 2, 6, 3, 4, 5, 7, 8, 9, 10, 11, 12]]
        )

    def _get_states(self, env_ids: list[int] | None = None) -> list[DictEnvState]:
        if env_ids is None:
            env_ids = list(range(self.num_envs))
        object_states = {}
        for obj_id, obj in enumerate(self.objects):
            if isinstance(obj, ArticulationObjCfg):
                joint_ids_reindex = self._get_joint_ids_reindex(obj.name)
                body_ids_reindex = self._get_body_ids_reindex(obj.name)
                root_state = self._root_states.view(self.num_envs, -1, 13)[:, obj_id, :]
                root_state = self._reorder_quat_xyzw_to_wxyz(root_state)
                body_state = self._rigid_body_states.view(self.num_envs, -1, 13)[:, body_ids_reindex, :]
                body_state = self._reorder_quat_xyzw_to_wxyz(body_state)
                state = ObjectState(
                    root_state=root_state,
                    body_names=self._get_body_names(obj.name),
                    body_state=body_state,
                    joint_pos=self._dof_states.view(self.num_envs, -1, 2)[:, joint_ids_reindex, 0],
                    joint_vel=self._dof_states.view(self.num_envs, -1, 2)[:, joint_ids_reindex, 1],
                )
            else:
                root_state = self._root_states.view(self.num_envs, -1, 13)[:, obj_id, :]
                root_state = self._reorder_quat_xyzw_to_wxyz(root_state)
                state = ObjectState(
                    root_state=root_state,
                )
            object_states[obj.name] = state

        robot_states = {}
        for robot_id, robot in enumerate(self.robots):
            joint_ids_reindex = self._get_joint_ids_reindex(robot.name)
            body_ids_reindex = self._get_body_ids_reindex(robot.name)
            root_state = self._root_states.view(self.num_envs, -1, 13)[:, len(self.objects) + robot_id, :]
            root_state = self._reorder_quat_xyzw_to_wxyz(root_state)
            body_state = self._rigid_body_states.view(self.num_envs, -1, 13)[:, body_ids_reindex, :]
            body_state = self._reorder_quat_xyzw_to_wxyz(body_state)

            state = RobotState(
                root_state=root_state,
                body_names=self._get_body_names(robot.name),
                body_state=body_state,
                joint_pos=self._dof_states.view(self.num_envs, -1, 2)[:, joint_ids_reindex, 0],
                joint_vel=self._dof_states.view(self.num_envs, -1, 2)[:, joint_ids_reindex, 1],
                joint_pos_target=None,  # TODO
                joint_vel_target=None,  # TODO
                # joint_effort_target=self._effort if self._manual_pd_on else None,
                # prefer measured forces from simulator over internal PD effort
                joint_effort_target=self._dof_force.view(self.num_envs, -1)[:, joint_ids_reindex],
            )
            # FIXME a temporary solution for accessing net contact forces of robots, it will be moved to
            extra = {
                "contact_forces": self._contact_forces.view(self.num_envs, -1, 3)[:, body_ids_reindex, :],
            }
            state.extra = extra
            robot_states[robot.name] = state

        camera_states = {}

        self.refresh_render()
        self.gym.start_access_image_tensors(self.sim)

        for cam_id, cam in enumerate(self.cameras):
            state = CameraState(
                rgb=torch.stack([self._rgb_tensors[env_id][cam_id][..., :3] for env_id in env_ids]),
                depth=-torch.stack([self._depth_tensors[env_id][cam_id] for env_id in env_ids]),  # -z
            )
            camera_states[cam.name] = state
        self.gym.end_access_image_tensors(self.sim)

        # Apply GS background rendering if enabled
        # TODO: Render with batch parallelization for efficiency
        if (
            self.scenario.gs_scene is not None
            and self.scenario.gs_scene.with_gs_background
            and self.gs_background is not None
        ):
            camera_states = self._apply_gs_background_rendering(camera_states, env_ids)

        extras = self.get_extra()  # extra observations
        return TensorState(objects=object_states, robots=robot_states, cameras=camera_states, extras=extras)

    ############################################################
    ## Gymnasium main methods
    ############################################################
    def _get_action_array_all(self, actions: list[Action]):
        action_array_list = []

        for action_data in actions:
            flat_vals = []
            for joint_i, joint_name in enumerate(self._joint_info[self.robots[0].name]["names"]):
                if self.robots[0].actuators[joint_name].fully_actuated:
                    flat_vals.append(
                        action_data[self.robots[0].name]["dof_pos_target"][joint_name]
                    )  # TODO: support other actions
                else:
                    flat_vals.append(0.0)  # place holder for under-actuated joints

            action_array = torch.tensor(flat_vals, dtype=torch.float32, device=self.device).unsqueeze(0)

            action_array_list.append(action_array)
        action_array_all = torch.cat(action_array_list, dim=0)
        return action_array_all

    def set_dof_targets(self, actions: list[Action] | torch.Tensor):
        self._actions_cache = actions
        action_input = torch.zeros_like(self._dof_states[:, 0])
        if isinstance(actions, torch.Tensor):
            # reverse sorted joint indices
            # TODO: support multiple robots
            reverse_reindex = self.get_joint_reindex(obj_name=self.robots[0].name, inverse=True)
            self._actions_cache = actions[:, reverse_reindex]
            action_array_all = self._actions_cache

        else:
            action_array_all = self._get_action_array_all(actions)

        assert (
            action_input.shape[0] % self._num_envs == 0
        )  # WARNING: obj dim(env0), robot dim(env0), obj dim(env1), robot dim(env1) ...

        if not hasattr(self, "_robot_dim_index"):
            robot_dim = action_array_all.shape[1]
            chunk_size = action_input.shape[0] // self._num_envs
            self._robot_dim_index = [
                i * chunk_size + offset
                for i in range(self.num_envs)
                for offset in range(chunk_size - robot_dim, chunk_size)
            ]
        action_input[self._robot_dim_index] = action_array_all.float().to(self.device).reshape(-1)

        # if any effort joint exist, set pd controller's target position for later effort calculation
        if self._manual_pd_on:
            actions_reshape = action_input.view(self._num_envs, self._obj_num_dof + self._robot_num_dof)
            self.actions = actions_reshape[:, self._obj_num_dof :]
            self.gym.set_dof_actuation_force_tensor(self.sim, gymtorch.unwrap_tensor(actions_reshape))
            # and set position target for position actuator if any exist
            if len(self._pos_ctrl_dof_dix) > 0:
                self.gym.set_dof_position_target_tensor(self.sim, gymtorch.unwrap_tensor(action_input))

        # directly set position target
        else:
            self.gym.set_dof_position_target_tensor(self.sim, gymtorch.unwrap_tensor(action_input))

    def set_actions(self, actions: torch.Tensor) -> None:
        action_input = torch.zeros_like(self._dof_states[:, 0])

        if not hasattr(self, "_robot_dim_index"):
            chunk = action_input.shape[0] // self._num_envs
            robot_dim = actions.shape[1]
            self._robot_dim_index = [
                env * chunk + (chunk - robot_dim) + i for env in range(self._num_envs) for i in range(robot_dim)
            ]

        action_input[self._robot_dim_index] = actions.to(self.device).reshape(-1)

        if self._manual_pd_on:
            self.actions = action_input.view(self._num_envs, self._obj_num_dof + self._robot_num_dof)[
                :, self._obj_num_dof :
            ]
            if self._pos_ctrl_dof_dix:
                self.gym.set_dof_position_target_tensor(self.sim, gymtorch.unwrap_tensor(action_input))
        else:
            self.gym.set_dof_position_target_tensor(self.sim, gymtorch.unwrap_tensor(action_input))

    def refresh_render(self) -> None:
        # Step the physics
        self.gym.simulate(self.sim)
        self.gym.fetch_results(self.sim, True)
        self._render()

    def _simulate_one_physics_step(self):
        self.gym.simulate(self.sim)
        if self.device == "cpu":
            self.gym.fetch_results(self.sim, True)
        self.gym.refresh_dof_state_tensor(self.sim)

    def _simulate(self) -> None:
        # Step the physics
        for _ in range(self.decimation):
            self._simulate_one_physics_step()
        self.gym.refresh_rigid_body_state_tensor(self.sim)
        self.gym.refresh_actor_root_state_tensor(self.sim)
        self.gym.refresh_jacobian_tensors(self.sim)
        self.gym.refresh_mass_matrix_tensors(self.sim)
        self.gym.refresh_net_contact_force_tensor(self.sim)
        self.gym.refresh_dof_force_tensor(self.sim)
        # Refresh cameras and viewer
        self._render()

    def _render(self) -> None:
        """Listen for keyboard events, step graphics and render the environment"""
        if not self.headless:
            for evt in self.gym.query_viewer_action_events(self.viewer):
                if evt.action == "toggle_viewer_sync" and evt.value > 0:
                    self._enable_viewer_sync = not self._enable_viewer_sync
        if self._enable_viewer_sync or len(self.cameras) > 0:
            self.gym.step_graphics(self.sim)
            if len(self.cameras) > 0:
                self.gym.render_all_camera_sensors(self.sim)
            if self._enable_viewer_sync:
                self.gym.draw_viewer(self.viewer, self.sim, False)
        else:
            if not self.headless:
                self.gym.poll_viewer_events(self.viewer)

    def _set_states(self, states: list[DictEnvState] | torch.Tensor, env_ids: list[int] | None = None):
        ## Support setting status only for specified env_ids
        if env_ids is None:
            env_ids = list(range(self.num_envs))

        # if states is list[DictEnvState], iterate over it and set state
        if isinstance(states, list):
            pos_list = []
            rot_list = []
            q_list = []
            states_flat = [{**states[i]["objects"], **states[i]["robots"]} for i in env_ids]

            # Prepare state data for specified env_ids
            env_indices = {env_id: i for i, env_id in enumerate(env_ids)}

            for i in range(self.num_envs):
                if i not in env_indices:
                    continue

                state_idx = env_indices[i]
                state = states_flat[state_idx]

                pos_list_i = []
                rot_list_i = []
                q_list_i = []
                for obj in self.objects:
                    obj_name = obj.name
                    pos = np.array(state[obj_name].get("pos", [0.0, 0.0, 0.0]))
                    rot = np.array(state[obj_name].get("rot", [1.0, 0.0, 0.0, 0.0]))
                    obj_quat = [rot[1], rot[2], rot[3], rot[0]]  # IsaacGym convention

                    pos_list_i.append(pos)
                    rot_list_i.append(obj_quat)
                    if isinstance(obj, ArticulationObjCfg):
                        obj_joint_q = np.zeros(len(self._articulated_joint_dict_dict[obj_name]))
                        articulated_joint_dict = self._articulated_joint_dict_dict[obj_name]
                        for joint_name, joint_idx in articulated_joint_dict.items():
                            if "dof_pos" in state[obj_name]:
                                obj_joint_q[joint_idx] = state[obj_name]["dof_pos"][joint_name]
                            else:
                                log.warning(f"No dof_pos for {joint_name} in {obj_name}")
                                obj_joint_q[joint_idx] = 0.0
                        q_list_i.append(obj_joint_q)

                pos_list_i.append(np.array(state[self.robots[0].name].get("pos", [0.0, 0.0, 0.0])))
                rot = np.array(state[self.robots[0].name].get("rot", [1.0, 0.0, 0.0, 0.0]))
                robot_quat = [rot[1], rot[2], rot[3], rot[0]]
                rot_list_i.append(robot_quat)

                robot_dof_state_i = np.zeros(len(self._robot_joint_dict))
                if "dof_pos" in state[self.robots[0].name]:
                    for joint_name, joint_idx in self._robot_joint_dict.items():
                        robot_dof_state_i[joint_idx] = state[self.robots[0].name]["dof_pos"][joint_name]
                else:
                    for joint_name, joint_idx in self._robot_joint_dict.items():
                        robot_dof_state_i[joint_idx] = (
                            self.robots[0].joint_limits[joint_name][0] + self.robots[0].joint_limits[joint_name][1]
                        ) / 2

                q_list_i.append(robot_dof_state_i)
                pos_list.append(pos_list_i)
                rot_list.append(rot_list_i)
                q_list.append(q_list_i)

            self._set_actor_root_state(pos_list, rot_list, env_ids)
            self._set_actor_joint_state(q_list, env_ids)

        # if states is TensorState, reindex the tensors and set state
        elif isinstance(states, TensorState):
            new_root_states = self._root_states.clone()
            new_dof_states = self._dof_states.clone()

            # Calculate actor indices that need to be updated
            actor_indices = []
            for env_id in env_ids:
                env_offset = env_id * (len(self.objects) + len(self.robots))
                # Set object states
                for idx, obj in enumerate(self.objects):
                    obj_state = states.objects[obj.name]
                    root_state = self._reorder_quat_xyzw_to_wxyz(obj_state.root_state, reverse=True)
                    actor_idx = env_offset + idx
                    new_root_states[actor_idx, :] = root_state[env_id, :]
                    actor_indices.append(actor_idx)

                    # Only articulated objects have DOF states
                    if isinstance(obj, ArticulationObjCfg):
                        joint_pos = obj_state.joint_pos
                        joint_vel = obj_state.joint_vel
                        # Get global joint indices for this object (already in sorted order)
                        joint_ids_global = self._get_joint_ids_reindex(obj.name)
                        # Set DOF states - convert to linear indices
                        # Note: joint_pos/vel from get_states are already in sorted order
                        for local_j_idx, global_j_idx in enumerate(joint_ids_global):
                            # Calculate linear index in flattened DOF tensor
                            dof_linear_idx = env_id * self._num_joints + global_j_idx
                            new_dof_states[dof_linear_idx, 0] = joint_pos[env_id, local_j_idx]
                            new_dof_states[dof_linear_idx, 1] = joint_vel[env_id, local_j_idx]

                # Set robot states
                for idx, robot in enumerate(self.robots):
                    robot_state = states.robots[robot.name]
                    root_state = self._reorder_quat_xyzw_to_wxyz(robot_state.root_state, reverse=True)
                    actor_idx = env_offset + len(self.objects) + idx
                    new_root_states[actor_idx, :] = root_state[env_id, :]
                    actor_indices.append(actor_idx)
                    joint_pos = robot_state.joint_pos
                    joint_vel = robot_state.joint_vel
                    # Get global joint indices for robot (already in sorted order)
                    joint_ids_global = self._get_joint_ids_reindex(robot.name)
                    # Set DOF states - convert to linear indices
                    # Note: joint_pos/vel from get_states are already in sorted order
                    for local_j_idx, global_j_idx in enumerate(joint_ids_global):
                        # Calculate linear index in flattened DOF tensor
                        dof_linear_idx = env_id * self._num_joints + global_j_idx
                        new_dof_states[dof_linear_idx, 0] = joint_pos[env_id, local_j_idx]
                        new_dof_states[dof_linear_idx, 1] = joint_vel[env_id, local_j_idx]

            # Convert the actor indices to a tensor
            actor_indices_tensor = torch.tensor(actor_indices, dtype=torch.int32, device=self.device)

            # Use indexed setting to set the root state
            self.gym.set_actor_root_state_tensor_indexed(
                self.sim,
                gymtorch.unwrap_tensor(new_root_states),
                gymtorch.unwrap_tensor(actor_indices_tensor),
                len(actor_indices),
            )

            # Use set_dof_state_tensor (not indexed) as we've already modified the specific DOF indices
            self.gym.set_dof_state_tensor(
                self.sim,
                gymtorch.unwrap_tensor(new_dof_states),
            )
        else:
            raise Exception("Unsupported state type, must be DictEnvState or TensorState")

        # Refresh tensors
        self.gym.refresh_rigid_body_state_tensor(self.sim)
        self.gym.refresh_actor_root_state_tensor(self.sim)
        self.gym.refresh_dof_state_tensor(self.sim)
        self.gym.refresh_jacobian_tensors(self.sim)
        self.gym.refresh_mass_matrix_tensors(self.sim)
        self.gym.refresh_net_contact_force_tensor(self.sim)

        # reset all env_id action to default
        self.actions[env_ids] = 0.0

    def _set_actor_root_state(self, position_list, rotation_list, env_ids):
        new_root_states = self._root_states.clone()
        actor_indices = []

        # Only modify the positions and rotations for the specified env_ids
        for i, env_id in enumerate(env_ids):
            env_offset = env_id * (len(self.objects) + 1)  # objects + robot
            for j in range(len(self.objects) + 1):
                actor_idx = env_offset + j
                new_root_states[actor_idx, :3] = torch.tensor(
                    position_list[i][j], dtype=torch.float32, device=self.device
                )
                new_root_states[actor_idx, 3:7] = torch.tensor(
                    rotation_list[i][j], dtype=torch.float32, device=self.device
                )
                new_root_states[actor_idx, 7:13] = torch.zeros(6, dtype=torch.float32, device=self.device)
            actor_indices.extend(range(env_offset, env_offset + len(self.objects) + 1))

        # Convert the actor indices to a tensor
        root_reset_actors_indices = torch.tensor(actor_indices, dtype=torch.int32, device=self.device)

        # Use indexed setting to set the root state
        res = self.gym.set_actor_root_state_tensor_indexed(
            self.sim,
            gymtorch.unwrap_tensor(new_root_states),
            gymtorch.unwrap_tensor(root_reset_actors_indices),
            len(root_reset_actors_indices),
        )
        assert res

        return

    def _set_actor_joint_state(self, joint_pos_list, env_ids):
        new_dof_states = self._dof_states.clone()

        # Calculate the indices of DOFs in the tensor
        dof_indices = []
        new_dof_pos_values = []

        for i, env_id in enumerate(env_ids):
            # Get the joint positions for this environment
            flat_vals = []
            for obj_joints in joint_pos_list[i]:
                flat_vals.extend(obj_joints)

            # Calculate the indices of DOFs in the global DOF tensor
            dof_start_idx = env_id * self._num_joints
            for j, val in enumerate(flat_vals):
                dof_idx = dof_start_idx + j
                dof_indices.append(dof_idx)
                new_dof_pos_values.append(val)

        # Update the DOF positions for the specified indices
        dof_indices_tensor = torch.tensor(dof_indices, dtype=torch.int64, device=self.device)
        new_dof_pos_tensor = torch.tensor(new_dof_pos_values, dtype=torch.float32, device=self.device)

        # Update the positions and velocities (set velocities to 0)
        new_dof_states[dof_indices_tensor, 0] = new_dof_pos_tensor
        new_dof_states[dof_indices_tensor, 1] = 0.0

        # Apply the updated DOF state
        res = self.gym.set_dof_state_tensor(self.sim, gymtorch.unwrap_tensor(new_dof_states))
        assert res

        return

    def close(self) -> None:
        try:
            self.gym.destroy_sim(self.sim)
            self.gym.destroy_viewer(self.viewer)
            self.gym = None
            self.sim = None
            self.viewer = None
        except Exception as e:
            log.error(f"Error closing IsaacGym environment: {e}")
            pass

    ############################################################
    ## Utils
    ############################################################
    def _get_joint_names(self, obj_name: str, sort: bool = True) -> list[str]:
        if isinstance(self.object_dict[obj_name], ArticulationObjCfg):
            joint_names = deepcopy(list(self._joint_info[obj_name]["names"]))
            if sort:
                joint_names.sort()
            return joint_names
        else:
            return []

    def _get_body_names(self, obj_name: str, sort: bool = True) -> list[str]:
        if isinstance(self.object_dict[obj_name], ArticulationObjCfg):
            body_names = deepcopy(self._body_info[obj_name]["names"])
            if sort:
                body_names.sort()
            return body_names
        else:
            return []

    def _get_body_ids_reindex(self, obj_name: str) -> list[int]:
        return [self._body_info[obj_name]["global_indices"][bn] for bn in self._get_body_names(obj_name)]

    def _get_joint_ids_reindex(self, obj_name: str) -> list[int]:
        return [self._joint_info[obj_name]["global_indices"][jn] for jn in self._get_joint_names(obj_name)]

    def rand_rigid_body_fric(self, cfg, env_id: int, props: list[gymapi.RigidShapeProperties]):
        """Randomize the friction of the rigid bodies."""
        if not cfg.enabled:
            return
        if not hasattr(self, "_rand_fric_dist"):
            params_dict = class_to_dict(cfg)
            params_dict["num_envs"] = self.num_envs
            params_dict["device"] = self.device
            dist_fn = cfg.dist_fn
            self._rand_fric_dist = dist_fn(params_dict)
        # TODO: add rigid body id index
        for s in range(len(props)):
            props[s].friction = self._rand_fric_dist[env_id]
        return props

    def rand_rigid_body_mass(self, cfg, env_id: int, props: list[gymapi.RigidBodyProperties]):
        """Randomize the base mass."""
        if not cfg.enabled:
            return
        if not hasattr(self, "_rand_mass_dist"):
            params_dict = class_to_dict(cfg)
            params_dict["num_envs"] = self.num_envs
            params_dict["device"] = self.device
            dist_fn = cfg.dist_fn
            self._rand_mass_dist = dist_fn(params_dict)
        # TODO: add rigid body id index
        props[0].mass += self._rand_mass_dist[env_id]
        return props

    @property
    def num_envs(self) -> int:
        return self._num_envs

    @property
    def actions_cache(self) -> list[Action]:
        return self._actions_cache

    @property
    def device(self) -> torch.device:
        return self._device

    @property
    def default_dof_pos(self) -> torch.tensor:
        joint_reindex = self.get_joint_reindex(self.robot.name)
        return self._robot_default_dof_pos[:, joint_reindex]

    @property
    def torque_limits(self) -> torch.tensor:
        joint_reindex = self.get_joint_reindex(self.robot.name)
        return self._torque_limits[:, joint_reindex]

    @property
    def robot_num_dof(self) -> int:
        return self._robot_num_dof


# TODO: try to align handler API and use GymWrapper instead
# IsaacgymEnv: type[EnvWrapper[IsaacgymHandler]] = GymEnvWrapper(IsaacgymHandler)
